package com.pactera.madp.filter;

import cn.hutool.core.io.IoUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.crypto.digest.DigestAlgorithm;
import cn.hutool.crypto.digest.Digester;
import cn.hutool.http.HttpStatus;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import com.pactera.madp.constant.SecurityConstants;
import com.pactera.madp.exception.AuthException;
import com.pactera.madp.util.R;
import com.pactera.madp.util.WebUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.HttpMethod;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.InputStream;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;

/**
 * All rights Reserved, Designed By teana@sina.cn.
 *
 * @Author: Deming.Chang
 * @Email: teana@sina.cn
 * @Date: 2018/9/20 00:57
 * @Version: 1.0.0
 * @Copyright: 2018
 * @Description: Servlet filter that verifies every request whose URI does not start with
 * {@code /common} before passing it down the chain. A request is accepted only when:
 * <ol>
 *   <li>the {@code Madp-Timestamp} (13-digit epoch millis), {@code Madp-Nonce} and
 *       {@code Madp-Sign} headers are present and non-blank;</li>
 *   <li>{@code Madp-Sign} equals the MD5 hex digest of the request URL followed by
 *       {@code ?} and the key-sorted parameters (top-level JSON body fields for POST/PUT,
 *       servlet request parameters for GET/DELETE);</li>
 *   <li>the timestamp is within {@code noRepeatTime} milliseconds of the server clock,
 *       in either direction;</li>
 *   <li>the nonce has not been used before (tracked in Redis).</li>
 * </ol>
 * Any failure is answered with HTTP 400 and an {@code R.error(...)} JSON body.
 * <p>
 * Configuration: init-param {@code noRepeatTime} (milliseconds, must be positive) and a bean
 * named {@code redisTemplate} in the web application context.
 * <p>
 * NOTE(review): the signature is an unkeyed MD5 over data the client already has, so it only
 * detects accidental modification; anyone who knows the scheme can forge it. Consider an HMAC
 * with a shared secret if authenticity is required.
 */
@Slf4j
public class AuthFilter implements Filter {

    /** Redis client used to remember nonces; resolved from the web application context in {@link #init}. */
    private RedisTemplate redisTemplate;
    /** Allowed clock skew and minimum nonce lifetime, in milliseconds (init-param {@code noRepeatTime}). */
    private long noRepeatTime;

    /**
     * Resolves the {@code redisTemplate} bean and reads the {@code noRepeatTime} init-param.
     *
     * @param filterConfig filter configuration supplied by the servlet container
     * @throws ServletException if {@code noRepeatTime} is missing, not a number, or not positive
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        ServletContext servletContext = filterConfig.getServletContext();
        WebApplicationContext wac = WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext);
        redisTemplate = (RedisTemplate) wac.getBean("redisTemplate");

        String noRepeatTimeString = filterConfig.getInitParameter("noRepeatTime");
        if (StrUtil.isBlank(noRepeatTimeString)) {
            throw new ServletException("AuthFilter init-param 'noRepeatTime' is required");
        }
        try {
            noRepeatTime = Long.parseLong(noRepeatTimeString.trim());
        } catch (NumberFormatException e) {
            throw new ServletException("AuthFilter init-param 'noRepeatTime' must be a number of milliseconds: "
                    + noRepeatTimeString, e);
        }
        // A non-positive window leaves no usable time window and cannot serve as a nonce TTL.
        if (noRepeatTime <= 0) {
            throw new ServletException("AuthFilter init-param 'noRepeatTime' must be positive: " + noRepeatTime);
        }
    }

    /**
     * Verifies headers, signature, timestamp window and nonce; on success forwards the wrapped
     * request, on failure writes an HTTP 400 JSON error and stops the chain.
     *
     * @param servletRequest  incoming request (expected to be an {@link HttpServletRequest})
     * @param servletResponse outgoing response (expected to be an {@link HttpServletResponse})
     * @param filterChain     the remaining filter chain
     * @throws IOException      if the request wrapper cannot be created or the chain fails
     * @throws ServletException if the downstream chain fails
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {

        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        // Requests under /common are exempt from verification and pass through untouched.
        if (StrUtil.startWith(request.getRequestURI(), "/common")) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }

        // Only built for verified paths. NOTE(review): presumably the wrapper buffers the body so it
        // can be read here for signing and again downstream — confirm in KoalaHttpRequestWrapper.
        KoalaHttpRequestWrapper requestWrapper = new KoalaHttpRequestWrapper(request);

        try {
            // 1.0 Read the required headers
            String timestamp = request.getHeader("Madp-Timestamp");
            String nonce = request.getHeader("Madp-Nonce");
            String sign = request.getHeader("Madp-Sign");

            Assert.notNull(timestamp, "[Assertion failed] - Madp-Timestamp header is required; it must not be null");
            Assert.notNull(nonce, "[Assertion failed] - Madp-Nonce header is required; it must not be null");
            Assert.notNull(sign, "[Assertion failed] - Madp-Sign header is required; it must not be null");

            Assert.isTrue(StrUtil.isNotBlank(timestamp), "[Assertion failed] - Madp-Timestamp header is required; it must not be empty");
            Assert.isTrue(StrUtil.isNotBlank(nonce), "[Assertion failed] - Madp-Nonce header is required; it must not be empty");
            Assert.isTrue(StrUtil.isNotBlank(sign), "[Assertion failed] - Madp-Sign header is required; it must not be empty");
            Assert.isTrue(isTimestamp(timestamp), "[Assertion failed] - Madp-Timestamp header must be timestamp");

            // 2.0 Integrity: reject requests whose parameters were tampered with
            String signStr = createSign(requestWrapper, response, sign);
            Assert.isTrue(StrUtil.equals(sign, signStr), "[Assertion failed] - Request parameters had been tampered");

            // 3.0 Freshness: reject stale or far-future timestamps (stops reuse of captured links)
            Assert.isTrue(validRequest(timestamp), "[Assertion failed] - Request illegal");

            // 4.0 Uniqueness: reject an already-used nonce (anti-replay)
            Assert.isTrue(validRepeatRequest(nonce, timestamp), "[Assertion failed] - Request duplicate");

        } catch (Exception e) {
            response.setStatus(HttpStatus.HTTP_BAD_REQUEST);
            log.error("{}", e.getMessage(), e);
            WebUtils.renderJson(response, R.error(e.getMessage()));
            return;
        }
        filterChain.doFilter(requestWrapper, servletResponse);
    }

    @Override
    public void destroy() {
        // Nothing to release: the Redis template is a bean owned by the Spring context.
    }

    /**
     * Checks whether the string is a 13-digit epoch-millisecond timestamp.
     *
     * @param timestamp candidate timestamp string
     * @return {@code true} if it parses as a {@code long} and has exactly 13 characters
     */
    private static boolean isTimestamp(String timestamp) {
        // isLong is checked first so a null input never reaches length().
        return NumberUtil.isLong(timestamp) && timestamp.length() == 13;
    }

    /**
     * Computes the expected signature for the request: MD5 hex of the request URL, followed by
     * {@code ?} and the key-sorted parameters when there are any.
     *
     * @param requestWrapper the wrapped request whose body/parameters are signed
     * @param response       response, set to HTTP 400 when a POST/PUT body is not JSON
     * @param sign           the client-supplied signature, used only for logging
     * @return the lower-case MD5 hex digest the client is expected to have sent
     * @throws IOException   if the request body cannot be read
     * @throws AuthException if a non-blank POST/PUT body is not valid JSON
     */
    private String createSign(KoalaHttpRequestWrapper requestWrapper,
                              HttpServletResponse response,
                              String sign) throws IOException {

        Map<String, Object> map = new HashMap<>();
        String method = requestWrapper.getMethod();

        // 1.0 POST|PUT: the signed parameters are the top-level fields of the JSON body
        if (StrUtil.equalsIgnoreCase(method, HttpMethod.POST.name())
                || StrUtil.equalsIgnoreCase(method, HttpMethod.PUT.name())) {
            InputStream is = requestWrapper.getInputStream();
            String body = IoUtil.read(is, "UTF-8");
            if (StrUtil.isNotBlank(body)) {
                if (!JSONUtil.isJson(body)) {
                    response.setStatus(HttpStatus.HTTP_BAD_REQUEST);
                    throw new AuthException(AuthException.REQUEST_BODY_NOT_JSON);
                }
                map = JSONUtil.toBean(body, Map.class);
            }
        }

        // 2.0 GET|DELETE: the signed parameters are the servlet request parameters
        if (StrUtil.equalsIgnoreCase(method, HttpMethod.GET.name())
                || StrUtil.equalsIgnoreCase(method, HttpMethod.DELETE.name())) {
            Map<String, String[]> parameterMap = requestWrapper.getParameterMap();
            if (MapUtil.isNotEmpty(parameterMap)) {
                map.putAll(parameterMap);
            }
        }

        // 3.0 Assemble the data to sign: URL + "?" + key-sorted params (multi-values joined by ",")
        StringBuffer signMap = requestWrapper.getRequestURL();
        if (MapUtil.isNotEmpty(map)) {
            // TreeMap gives a deterministic key order that the client can reproduce.
            TreeMap<String, String> params = new TreeMap<>();
            for (Map.Entry<String, Object> entry : map.entrySet()) {
                params.put(entry.getKey(), StrUtil.join(",", entry.getValue()));
            }
            signMap.append("?").append(HttpUtil.toParams(params));
        }

        // 4.0 Sign it
        Digester md5 = new Digester(DigestAlgorithm.MD5);
        String signStr = md5.digestHex(signMap.toString());
        // Debug level: the sign map contains every request parameter, which may include sensitive values.
        log.debug("Madp-Sign = {}", sign);
        log.debug("Sign-Map = {}", signMap);
        log.debug("Sign-Str = {}", signStr);

        return signStr;
    }

    /**
     * Checks whether the request is inside its validity window: the client timestamp may differ
     * from the server clock by at most {@code noRepeatTime} milliseconds in either direction.
     * Rejecting far-future timestamps matters because the nonce is only remembered for a bounded
     * time; an unbounded future timestamp would make the request replayable after that.
     *
     * @param timestamp client timestamp in epoch milliseconds (already validated by {@link #isTimestamp})
     * @return {@code true} if the timestamp is within the allowed window
     */
    private boolean validRequest(String timestamp) {
        long skew = System.currentTimeMillis() - Long.parseLong(timestamp);
        return Math.abs(skew) <= noRepeatTime;
    }

    /**
     * Checks whether the request is a repeat (anti-replay) and, if not, records its nonce.
     * The record uses an atomic set-if-absent, so concurrent requests carrying the same nonce
     * cannot both pass.
     *
     * @param nonce     client nonce
     * @param timestamp client timestamp in epoch milliseconds
     * @return {@code true} if this is the first use of the nonce, {@code false} otherwise
     */
    private boolean validRepeatRequest(String nonce, String timestamp) {
        String key = SecurityConstants.NONCE_PREFIX + nonce;
        // Numeric addition: the moment after which this timestamp is no longer accepted.
        long expireAt = Long.parseLong(timestamp) + noRepeatTime;

        Boolean firstUse = redisTemplate.opsForValue().setIfAbsent(key, String.valueOf(expireAt));
        if (!Boolean.TRUE.equals(firstUse)) {
            return false;
        }
        // Keep the nonce at least until the timestamp leaves the validity window (a slightly
        // future timestamp stays valid for up to 2 * noRepeatTime), and never less than noRepeatTime.
        // NOTE(review): on Spring Data Redis 2.1+ setIfAbsent(key, value, timeout, unit) does both atomically.
        long ttl = Math.max(expireAt - System.currentTimeMillis(), noRepeatTime);
        redisTemplate.expire(key, ttl, TimeUnit.MILLISECONDS);
        return true;
    }
}
